# -*- coding: utf-8 -*-
"""
Created on Tue Dec  7 12:26:25 2021

@author: zhiyinan
"""
import numpy as np
# import matplotlib.pyplot as plt
# import tensorflow as tf
import Parameters1D_casadi as par 
from Model1D_casadi_crop import *
import van_Genuchten1D_casadi as vg 
# from scipy import integrate
from casadi import *
# import matplotlib.font_manager as font_manager
import time
# import pandas as pd
import random


# %% Define the original model used by the optimal-control (zone MPC) problem
# Nx = 1
# Spatial discretisation. totalNodes is used below as the size of the
# integrator state vector (MX.sym('x', totalNodes)).
totalDepth, axialNodes, axialDistance, totalNodes = par.spatialVariables_1D()  # spatial parameters
# Temporal settings. samplingTimeInternal is the CVODES integration horizon of
# F and internalTimeSteps the number of F calls made by one sim_2h_MX step.
# NOTE(review): the argument 25000*6*60 is presumably a total duration for
# data generation -- confirm against Parameters1D_casadi.
samplingTime, samplingTimeInternal, internalTimeSteps, totTimeSteps = \
                                  par.temporalVariable_dataGen(25000*6*60)


# Soil parameter set; loamySoil is not defined in this file and presumably
# comes from the Model1D_casadi_crop star import.
# NOTE(review): soilPars is not passed to any live call in this file
# (vg.volumetricMoisture is called without it); confirm whether it is needed.
soilPars = loamySoil()
def ODE(head, irrigAmount, cropCoeff, refEvap, noise_level=0.0):
    """Return the right-hand side of the 1-D soil-water model plus an error term.

    Wraps ``Richards1D_casadi``. With the default ``noise_level=0`` the error
    term is a zero vector, so the model output is returned unchanged. A
    non-zero ``noise_level`` scales entry ``i`` by
    ``1 + s_i * noise_level * r_i`` (``s_i`` drawn from {-1, 1}, ``r_i``
    uniform on [0, 1)). The factors are drawn once, when this function is
    called. For a symbolic ``head`` they therefore become constants in the
    expression graph.

    Parameters
    ----------
    head : state vector forwarded to Richards1D_casadi (the integrator state
        ``x`` when used to build ``F``).
    irrigAmount, cropCoeff, refEvap : model inputs, forwarded unchanged.
    noise_level : maximum relative model error; 0 (default) disables it.

    Returns
    -------
    The model right-hand side plus the error term.
    """
    origin = Richards1D_casadi(head, irrigAmount, cropCoeff, refEvap)
    # Size the error term from the model output instead of a hard-coded 26,
    # so it keeps matching the state size if the grid (totalNodes) changes.
    n_nodes = origin.shape[0]
    if noise_level:
        # Element-wise product rather than item assignment into a DM: a DM
        # cannot hold symbolic MX entries.
        factors = DM([np.random.choice([-1, 1]) * noise_level * np.random.random_sample()
                      for _ in range(n_nodes)])
        error = factors * origin
    else:
        error = DM.zeros(n_nodes)
    return origin + error

#symbolic variables
# x = MX.sym('x', totalNodes)
# I = MX.sym('I')
# C = MX.sym('C')
# rE = MX.sym('rE')

# Fm = Function('Fm', [x, I, C, rE], [ODE(x, I, C, rE)])


# def RungeKutta_4(T, head, irrigAmount, cropCoeff, refEvap):
#     k_1 = ODE(head, irrigAmount, cropCoeff, refEvap)
#     x_2 = head + 0.5*T*k_1
#     k_2 = ODE(x_2, irrigAmount, cropCoeff, refEvap)
#     x_3 = head + 0.5*T*k_2
#     k_3 = ODE(x_3, irrigAmount, cropCoeff, refEvap)
#     x_4 = head + T*k_3
#     k_4 = ODE(x_4, irrigAmount, cropCoeff, refEvap)
    
#     x_k = head + (T/6.0)*(k_1 + 2*k_2 + 2*k_3 + k_4)
#     return x_k

## Symbolic variables for the CasADi model.
# These were originally created for an RK4 CasADi function (RK4_fun, now
# commented out). x, I, C and rE are used to build the CVODES integrator F.
t = MX.sym('t')  # only referenced by commented-out RK4 code
x = MX.sym('x', totalNodes)  # head at each node (integrator state)
I = MX.sym('I')  # irrigation amount
C = MX.sym('C')  # crop coefficient
rE = MX.sym('rE')  # reference evaporation

# RK4_fun = Function('RK4_fun', [t, x, I, C, rE], [RungeKutta_4(t, x, I, C, rE)])

# Time settings, written as hours*60*60 (i.e. seconds, assuming the model
# uses SI time units, consistent with the m/s irrigation bounds used later).
# NOTE(review): sim_2h_MX advances internalTimeSteps*samplingTimeInternal,
# which comes from Parameters1D_casadi; confirm it equals Delta.
Delta = 2*60*60  # one control interval / MPC step (2 h)
Nt = 40*60*60  # prediction horizon (40 h -> int(Nt/Delta) = 20 steps)
Nsim = 5*24*60*60  # closed-loop simulation length (5 days -> 60 steps)

# Using shorter sampling times at time steps where the soil moisture is very high
# def sim(T, initHead, irrigAmnt, cropCoeff, refEvap):
#     interHead = initHead
#     samplingTimeInternal = 60
#     y = []
#     for i in range(int(T/Delta)):
#         ui = irrigAmnt[i]
#         for j in range(int(Delta/samplingTimeInternal)):
#             interRes = RungeKutta_4(samplingTimeInternal, interHead, ui, cropCoeff, refEvap)
#             interHead = interRes
#         y = vertcat(y,vg.volumetricMoisture(interHead[19]))
#     return y
# samplingTimeInternal = 120
# internalTimeSteps = int(Delta/samplingTimeInternal)

# Integrator for the model. CVODES integrates ODE over one internal step of
# length samplingTimeInternal, with the parameter vector
# p = [irrigation, crop coefficient, reference evaporation] held constant.
# F(x0=..., p=...) returns a dict; its 'xf' entry is the final state.
p = vertcat(I, C, rE)
ode = {'x': x, 'p': p, 'ode':ODE(x, p[0], p[1], p[2])}
# NOTE(review): newer CasADi releases (3.6+) deprecate passing 'tf' in the
# options dict in favour of integrator(name, solver, dae, t0, tf, opts);
# keep as is unless the installed version warns or errors.
opts = {'tf': samplingTimeInternal, 'regularity_check': True}
F = integrator('F', 'cvodes', ode, opts)

def sim_2h_MX(head, irrigAmount, cropCoeff, refEvap, noise_level=0.02):
    """Advance the soil-water model by one control interval and add noise.

    The CVODES integrator ``F`` is applied ``internalTimeSteps`` times in a
    row with constant inputs. The resulting head vector is then perturbed
    multiplicatively: entry ``i`` becomes ``head_i + f_i * head_i``, with
    ``f_i = s_i * noise_level * r_i``, ``s_i`` drawn from {-1, 1} and ``r_i``
    uniform on [0, 1).

    This works both for numeric heads (closed-loop plant step) and for
    symbolic MX heads, which ``zoneMPC`` produces when it passes a symbolic
    irrigation input to build the prediction model.

    Parameters
    ----------
    head : head vector at the start of the interval.
    irrigAmount, cropCoeff, refEvap : inputs packed into the integrator
        parameter vector ``p``.
    noise_level : maximum relative noise magnitude (default 0.02, the value
        previously hard-coded).

    Returns
    -------
    The perturbed head vector at the end of the interval (DM for numeric
    inputs, MX for symbolic ones).
    """
    # The inputs are constant over the interval, so build p once.
    p = vertcat(irrigAmount, cropCoeff, refEvap)
    for _ in range(internalTimeSteps):
        head = F(x0=head, p=p)["xf"]
    # Draw the noise factors numerically, in the same order as before
    # (sign, then magnitude, node by node), and apply them element-wise.
    # Item-assigning symbolic MX entries into a DM is not supported, so the
    # element-wise product is what makes the symbolic (zoneMPC) path work.
    n_nodes = head.shape[0]
    factors = DM([np.random.choice([-1, 1]) * noise_level * np.random.random_sample()
                  for _ in range(n_nodes)])
    return head + factors * head



# def sim_2h(head, irrigAmount, cropCoeff, refEvap, soilPars):
#     x_i = head
#     for j in range(int(Delta/samplingTimeInternal)):
#         sol = integrate.solve_ivp(ODE, args = (u_i, cropCoeff, refEvap, soilPars),
#                               t_span=[0,samplingTimeInternal],y0=tuple(x_i),method='RK45') # + Precip[i]
    
#         x_i = sol.y[:,-1]
#         # u_i = 0  # Precip[i]
    
#     y_i = vg.volumetricMoisture(x_i[y_index], soilPars)
    
#     return x_i, y_i
# %% Define functions that defines the Zone MPC problem with given conditions 
Nu = 1  # manipulated inputs per step
Ny = 1  # controlled outputs per step


def obj(u, y, yz, Q, R):
    """Stage cost of the zone MPC.

    Adds a Q-weighted quadratic penalty on the gap between the output ``y``
    and the zone target ``yz`` to an input penalty ``R*(1 - u)**2`` on the
    scaled input. The input penalty vanishes at ``u = 1``, which zoneMPC maps
    to the upper input bound ``ubu``.
    """
    deviation = y - yz
    zone_term = mtimes(deviation.T, mtimes(Q, deviation))
    input_term = R * (1 - u) ** 2
    return zone_term + input_term

# def obj(u, yl, yu, Ql, Qu, R):
#     J = mtimes(yl.T, mtimes(Ql, yl)) + mtimes(yu.T, mtimes(Qu, yu)) + R*u**2
#     return J

def zoneMPC(ug, yg, lbu, ubu, lby, uby, lbZone, ubZone,
            Q, R, x0, Nt, Delta, output_index=None,
            crop_coeff=None, ref_evap=None):
    """Build the zone-MPC nonlinear program for one closed-loop step.

    Decision variables are appended in the order
    ``[y_0, u_0, y_1, yz_0, u_1, y_2, yz_1, ...]``. Callers slice the solution
    vector by this ordering.

    * ``y_0`` is pinned, via equal bounds, to the moisture of the current state.
    * For each prediction step ``t`` there is:
      a scaled input ``u_t`` in [0, 1];
      a predicted output ``y_{t+1}`` in [lby, uby];
      a zone target ``yz_t`` in [lbZone, ubZone].

    The inputs ``u`` are scaled; ``y`` and ``yz`` are not.

    Parameters
    ----------
    ug, yg : initial guesses for every scaled input / predicted output.
    lbu, ubu : physical input bounds; ``u`` is mapped to ``u*(ubu-lbu)+lbu``.
    lby, uby : bounds on the predicted outputs.
    lbZone, ubZone : bounds of the moisture zone for ``yz``.
    Q, R : stage-cost weights passed to ``obj``.
    x0 : current head state.
    Nt, Delta : horizon and step length; ``int(Nt/Delta)`` steps are built.
    output_index : node whose volumetric moisture is the output. When None,
        the module-level ``y_index`` is used, so the prediction and the
        initial output ``y0`` refer to the same node.
    crop_coeff, ref_evap : crop coefficient and reference evaporation for the
        prediction model. When None, the module-level ``cropCoeff`` /
        ``refEvap`` are used (the previous behaviour).

    Returns
    -------
    W : list of MX decision variables (concatenate with ``vertcat(*W)``).
    W0, W_lb, W_ub : initial guess and bounds for W.
    f : objective expression.
    g, g_lb, g_ub : model equality constraints and their (zero) bounds.
    """
    if output_index is None:
        output_index = y_index
    if crop_coeff is None:
        crop_coeff = cropCoeff
    if ref_evap is None:
        ref_evap = refEvap

    W = []
    W0 = []
    W_lb = []
    W_ub = []
    f = 0
    g = []
    g_lb = []
    g_ub = []

    # Initial output: fixed to the moisture of the current state.
    y_meas = vg.volumetricMoisture(x0[output_index])
    W += [MX.sym('y_' + str(0))]
    W_lb = vertcat(W_lb, y_meas)
    W_ub = vertcat(W_ub, y_meas)
    W0 = vertcat(W0, y_meas)

    # Roll the soil-water model forward over the horizon. Each step's state
    # depends symbolically on the inputs chosen so far.
    xt = x0
    for t in range(int(Nt/Delta)):
        # Scaled system input u_t in [0, 1].
        u = MX.sym('u_' + str(t))
        W += [u]
        W_lb = vertcat(W_lb, 0)
        W_ub = vertcat(W_ub, 1)
        W0 = vertcat(W0, ug)

        # Predicted system output y_{t+1}.
        y = MX.sym('y_' + str(t+1))
        W += [y]
        W_lb = vertcat(W_lb, lby)
        W_ub = vertcat(W_ub, uby)
        W0 = vertcat(W0, yg)

        # Zone target yz_t, free within the zone.
        yz = MX.sym('yz_' + str(t))
        W += [yz]
        W_lb = vertcat(W_lb, lbZone)
        W_ub = vertcat(W_ub, ubZone)
        W0 = vertcat(W0, lbZone)

        # Model constraint: y_{t+1} equals the moisture predicted by the
        # model when the unscaled input is applied for one interval.
        xt = sim_2h_MX(xt, u*(ubu - lbu) + lbu, crop_coeff, ref_evap)
        g = vertcat(g, vg.volumetricMoisture(xt[output_index]) - y)
        g_lb = vertcat(g_lb, 0)
        g_ub = vertcat(g_ub, 0)

        # Stage cost: zone tracking plus input penalty.
        f += obj(u, y, yz, Q, R)

    return W, W0, W_lb, W_ub, f, g, g_lb, g_ub

# #%%  Import the weather condition 
# # data = pd.read_csv('C:/Users/zhiyinan/OneDrive - ualberta.ca/Irrigation/crop_model_1D\Weather.csv')
# # refEvap = (data.iloc[10:20,5].values)/1e3/(24*60*60)
# refEvap = np.random.normal(3, 0.2, 10)/1e3/(24*60*60)
# Evap = np.zeros(int(Nsim/Delta))
# for i in range(10):
#     Evap[i*12:i*12+12] = refEvap[i]
# # np.savetxt("C:/Users/zhiyinan/OneDrive - ualberta.ca/Irrigation/crop_model_1D/ZMPC Simulation Results/Evap.txt", Evap)
# # data_hr = pd.read_csv('C:/Users/zhiyinan/OneDrive - ualberta.ca/Irrigation/crop_model_1D\Rain_hourly.csv')
# # precip = data_hr.iloc[240:480,2].values/1e3/(60*60*24)

# Precip = np.zeros((5,12))

# for i in range(5):
#     start = random.randrange(5)
#     end = random.randrange(10)
#     for j in range(start, end):
#         Precip[i, j] = -(random.random()*(20-5)+5)/1e3/(24*60*60)

# Precip = Precip.flatten()
# # Precip = np.zeros(int(Nsim/Delta))

# # for i in range(int(Nsim/Delta)):
# #     Precip[i] = random.choice([0,0,0,0,1]) * (random.random()*(2-0.2)+0.2)/1e3/(24*60*60)

# plt.figure(8)
# plt.step(np.arange(Precip.shape[0]), Precip)
# # np.savetxt("C:/Users/zhiyinan/OneDrive - ualberta.ca/Irrigation/crop_model_1D/ZMPC Simulation Results/Precip.txt", Precip)
# # lzone_store = DM.zeros((int(Nsim/Delta),Ny))
# # uzone_store = DM.zeros((int(Nsim/Delta),Ny))
# # start = time.time()

# %% Closed-loop scenario settings
# NOTE(review): past, future, lbh and ubh are not used by the live code in
# this file (lbh/ubh only appear in the commented-out zone definition below).
past = 20
future = 1

# Index of the node whose volumetric moisture is the controlled output.
y_index = 19

lbh = -2
ubh = -0.1
# Physical irrigation bounds. zoneMPC optimises a scaled input u in [0, 1],
# mapped to u*(ubu - lbu) + lbu.
lbu = -5e-06  #  unit [m/s]
ubu = 0  #  unit [m/s]
# ubu = -0.32e-05  #  unit [m/s]
# lb_zone = vg.volumetricMoisture(lbh, soilPars)
# ub_zone = vg.volumetricMoisture(ubh, soilPars)
# Target moisture zone, one entry per closed-loop step.
lb_zone = 0.18*np.ones(int(Nsim/Delta))
ub_zone = 0.23*np.ones(int(Nsim/Delta))
# Bounds on the predicted outputs: volumetric moisture at heads -80 and -0.1.
lby = vg.volumetricMoisture(-80)
uby = vg.volumetricMoisture(-0.1)
# Stage-cost weights passed to obj(): Q on zone tracking, R on the input.
Q = 4000
# Qu = 900
# QU = 400
R = 100

# Initial guesses for the scaled inputs and predicted outputs
# (change e.g. to a random number).
ug = 0
yg = 0.2

# Initial state: uniform head h0 at every node. The state is sized by
# totalNodes so it matches the integrator state (MX.sym('x', totalNodes))
# instead of a hard-coded 26.
# u0 = 0
# y0 = 0.24545639881190118
h0 = -0.96
x_i = np.ones(int(totalNodes))*h0
y0 = vg.volumetricMoisture(x_i[y_index])

# Crop coefficient and reference evaporation, used for both the prediction
# model (zoneMPC) and the plant step (sim_2h_MX) in the closed loop.
cropCoeff = 0.88
refEvap = 3.102e-08



# %%

# Closed-loop zone-MPC simulation. The loop runs int(Nsim/Delta) steps. At
# each step:
#   1. the NLP is rebuilt from the current state x_i;
#   2. it is solved with IPOPT;
#   3. its first input is applied to the plant model sim_2h_MX;
#   4. the state is advanced.
start = time.time()
# u = MX.sym('u',(int(Nt/Delta), Nu))
# y = MX.sym('y',(int(Nt/Delta)+1, Ny))
# yz = MX.sym('yz',(int(Nt/Delta), Ny))
# yl = MX.sym('yl',(int(Nt/Delta), Ny))
# yu = MX.sym('yu',(int(Nt/Delta), Ny))
# sUB = MX.sym('sUB',(int(Nt/Delta), Ny))

# Storage for the closed-loop results, one row per step.
u_opt = DM.zeros((int(Nsim/Delta),Nu))  # applied (unscaled) inputs
y_opt = DM.zeros((int(Nsim/Delta)+1,Ny))  # row 0: initial output; row i+1: see NOTE in the loop
y_z = DM.zeros((int(Nsim/Delta),Ny))  # zone target yz_0 chosen at step i
# y_u = DM.zeros((int(Nsim/Delta),Ny))
# y_l = DM.zeros((int(Nsim/Delta),Ny))
y_opt[0,:] = y0
# y_pred = DM.zeros((int(Nsim/Delta)+1,Ny))
# y_pred[0,:] = y0
# y_e = DM.zeros((int(Nsim/Delta)+2,Ny))
# y_e[0:2] = 0.05
# add = DM.zeros((int(Nsim/Delta),Ny))
# NOTE(review): e_as, e_bs, e_a and e_b are not used by the live code in
# this file.
e_as = []
e_bs = []
e_a = 1
e_b = 0


for i in range(int(Nsim/Delta)): # int(Nsim/Delta)
    # add[i] = (y_e[i]+y_e[i+1])/2
    # Build the NLP for this step. The current state x_i enters the model
    # constraints as a numeric constant, which is why the problem (and the
    # IPOPT solver instance below) is recreated at every iteration.
    W, W0, W_lb, W_ub, f, g, g_lb, g_ub = zoneMPC(ug, yg, 
                                                  lbu, ubu, lby, uby, lb_zone[i], ub_zone[i], 
                                                  Q, R, x_i, Nt, Delta)
    
    nlp = {'x':vertcat(*W),'f':f,'g':g}
    S = nlpsol('S','ipopt',nlp)
    r = S(x0 = W0, lbx = W_lb, ubx = W_ub, lbg = g_lb, ubg = g_ub)   
    # The solution is ordered [y_0, u_0, y_1, yz_0, u_1, y_2, yz_1, ...],
    # the order in which zoneMPC appends variables. With Nu = Ny = 1 these
    # slices pick u_0, y_1 and yz_0. u_0 is mapped back to physical units.
    u_i = r['x'][Ny:Nu+Ny]*(ubu - lbu) + lbu
    # NOTE(review): this stores the optimizer's predicted y_1, not the
    # output of the plant step below (sim_2h_MX draws fresh noise on every
    # call). If the realised moisture is wanted, compute
    # vg.volumetricMoisture(x_i[y_index]) after the plant step -- confirm intent.
    y_opt[i+1,:] = r['x'][Nu+Ny:Nu+Ny*2] 
    y_z[i] = r['x'][Nu+Ny*2:Nu+Ny*3]
    # y_u[i] = r['x'][Nu+Ny*3:Nu+Ny*4]
    u_opt[i,:] = u_i  #*(ubu - lbu) + lbu
    # Apply the first optimal input to the plant model and advance the state.
    x_i = sim_2h_MX(x_i, u_i, cropCoeff, refEvap)
    
# Average wall-clock seconds per closed-loop step.
print((time.time()-start)/int(Nsim/Delta))

# u_sol = r['x'][1]*(ubu - lbu) + lbu
# y_sol = r['x'][2]
# for i in range(1,int(Nt/Delta)):
#     u_sol = vertcat(u_sol, r['x'][1+i*4]*(ubu - lbu) + lbu)
#     y_sol = vertcat(y_sol, r['x'][2+i*4])
# y_sol = vertcat(y_sol,r['x'][6])
# y_sol = vertcat(y_sol,r['x'][10])
# y_sol = vertcat(y_sol,r['x'][14])
# y_sol = vertcat(y_sol,r['x'][18])

# u_sol = r['x'][1]
# u_sol = vertcat(u_sol,r['x'][5])
# u_sol = vertcat(u_sol,r['x'][9])
# u_sol = vertcat(u_sol,r['x'][13])
# u_sol = vertcat(u_sol,r['x'][17])
# Re-evaluate the MPC stage cost along the closed-loop trajectory with the
# same obj(). The applied input is first rescaled back to [0, 1].
# NOTE(review): obj_opt is not printed or saved in the code visible here.
obj_opt = 0
for i in range(int(Nsim/Delta)):
    uc = (u_opt[i] - lbu)/(ubu-lbu)
    obj_opt += obj(uc, y_opt[i+1], y_z[i], Q, R)  
    # obj(u_opt[i], DM(max(y_opt[i] - lb_zone[i],0)), DM(max(ub_zone[i]-y_opt[i],0)), QL, QU,R)

# Save the closed-loop input and output trajectories as plain text.
# NOTE(review): the "0.02pn" tag presumably refers to the 2 % noise applied
# in sim_2h_MX; update the file names if that level changes.
np.savetxt("u_opt_Richards_0.02pn.txt", u_opt)
np.savetxt("y_opt_Richards_0.02pn.txt", y_opt)
# csfont = {'fontname':'Times New Roman'}
# font = font_manager.FontProperties(family='Times New Roman',
#                                     # weight='bold',
#                                     style='normal', size=15)


# tplot = np.linspace(0, Nsim/3600, num = int(Nsim/Delta), endpoint=False)

# # weather + 12 st updates 12.401
# # more rain + 6 st updates 22.987
# # more rain + 12 st updates 10.0228
# plt.figure(4)
# plt.plot(tplot, y_opt[:-1], label = "Richards Eqn Nominal")
# plt.legend(frameon=False, ncol=1)
# plt.xlabel('Time Steps', **csfont, fontsize = 17)
# plt.ylabel('Soil Moisture ($m^3/m^3$)', **csfont, fontsize = 17)


# plt.figure(5)
# plt.step(tplot, u_opt, label = "Richards Eqn Nominal")
# plt.legend(frameon=False, ncol=1)
# plt.xlabel('Time Steps', **csfont, fontsize = 17)
# plt.ylabel('Irrigation (m/s)', **csfont, fontsize = 17)

# plt.figure(6)
# plt.plot(y_e[1:], label = "Rain + 10% P Noise + 2% M Noise - no correction")